{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "punL79CN7Ox6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_ckMIh7O7s6D"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ph5eir3Pf-3z"
      },
      "source": [
        "# Optimizing the Text Generation Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l10c04_nlp_optimizing_the_text_generation_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_for_deep_learning/l10c04_nlp_optimizing_the_text_generation_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dCxhW3mtLmfb"
      },
      "source": [
        "You've already done some amazing work with generating new songs, but so far we've seen some issues with repetition and a fair amount of incoherence. By using more data and further tweaking the model, you'll be able to get improved results. We'll once again use the [Kaggle Song Lyrics Dataset](https://www.kaggle.com/mousehead/songlyrics) here."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4aHK2CYygXom"
      },
      "source": [
        "## Import TensorFlow and related functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2LmLTREBf5ng"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras.preprocessing.text import Tokenizer\n",
        "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
        "\n",
        "# Other imports for processing data\n",
        "import string\n",
        "import numpy as np\n",
        "import pandas as pd"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GmLTO_dpgge9"
      },
      "source": [
        "## Get the Dataset\n",
        "\n",
        "As noted above, we'll utilize the [Song Lyrics dataset](https://www.kaggle.com/mousehead/songlyrics) on Kaggle again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Bf5FVHfganK"
      },
      "outputs": [],
      "source": [
        "!wget --no-check-certificate \\\n",
        "    https://drive.google.com/uc?id=1LiJFZd41ofrWoBtW-pMYsfz1w8Ny0Bj8 \\\n",
        "    -O /tmp/songdata.csv"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jz9x-7dWihxx"
      },
      "source": [
        "## 250 Songs\n",
        "\n",
        "Now we've seen a model trained on just a small sample of songs, and how this often leads to repetition as you get further along in trying to generate new text. Let's switch to using the 250 songs instead, and see if our output improves. This will actually be nearly 10K lines of lyrics, which should be sufficient.\n",
        "\n",
        "Note that we won't use the full dataset here as it will take up quite a bit of RAM and processing time, but you're welcome to try doing so on your own later. If interested, you'll likely want to use only some of the more common words for the Tokenizer, which will help shrink processing time and memory needed (or else you'd have an output array hundreds of thousands of words long)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nWbMN_19jfRT"
      },
      "source": [
        "### Preprocessing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LRmPPJegovBe"
      },
      "outputs": [],
      "source": [
        "def tokenize_corpus(corpus, num_words=-1):\n",
        "  # Fit a Tokenizer on the corpus\n",
        "  if num_words > -1:\n",
        "    tokenizer = Tokenizer(num_words=num_words)\n",
        "  else:\n",
        "    tokenizer = Tokenizer()\n",
        "  tokenizer.fit_on_texts(corpus)\n",
        "  return tokenizer\n",
        "\n",
        "def create_lyrics_corpus(dataset, field):\n",
        "  # Remove all other punctuation\n",
        "  dataset[field] = dataset[field].str.replace('[{}]'.format(string.punctuation), '')\n",
        "  # Make it lowercase\n",
        "  dataset[field] = dataset[field].str.lower()\n",
        "  # Make it one long string to split by line\n",
        "  lyrics = dataset[field].str.cat()\n",
        "  corpus = lyrics.split('\\n')\n",
        "  # Remove any trailing whitespace\n",
        "  for l in range(len(corpus)):\n",
        "    corpus[l] = corpus[l].rstrip()\n",
        "  # Remove any empty lines\n",
        "  corpus = [l for l in corpus if l != '']\n",
        "\n",
        "  return corpus"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kIGedF3XjHj4"
      },
      "outputs": [],
      "source": [
        "def tokenize_corpus(corpus, num_words=-1):\n",
        "  # Fit a Tokenizer on the corpus\n",
        "  if num_words > -1:\n",
        "    tokenizer = Tokenizer(num_words=num_words)\n",
        "  else:\n",
        "    tokenizer = Tokenizer()\n",
        "  tokenizer.fit_on_texts(corpus)\n",
        "  return tokenizer\n",
        "\n",
        "# Read the dataset from csv - this time with 250 songs\n",
        "dataset = pd.read_csv('/tmp/songdata.csv', dtype=str)[:250]\n",
        "# Create the corpus using the 'text' column containing lyrics\n",
        "corpus = create_lyrics_corpus(dataset, 'text')\n",
        "# Tokenize the corpus\n",
        "tokenizer = tokenize_corpus(corpus, num_words=2000)\n",
        "total_words = tokenizer.num_words\n",
        "\n",
        "# There should be a lot more words now\n",
        "print(total_words)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "quoDmw_FkNBA"
      },
      "source": [
        "### Create Sequences and Labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kkLAf3HmkPSo"
      },
      "outputs": [],
      "source": [
        "sequences = []\n",
        "for line in corpus:\n",
        "\ttoken_list = tokenizer.texts_to_sequences([line])[0]\n",
        "\tfor i in range(1, len(token_list)):\n",
        "\t\tn_gram_sequence = token_list[:i+1]\n",
        "\t\tsequences.append(n_gram_sequence)\n",
        "\n",
        "# Pad sequences for equal input length \n",
        "max_sequence_len = max([len(seq) for seq in sequences])\n",
        "sequences = np.array(pad_sequences(sequences, maxlen=max_sequence_len, padding='pre'))\n",
        "\n",
        "# Split sequences between the \"input\" sequence and \"output\" predicted word\n",
        "input_sequences, labels = sequences[:,:-1], sequences[:,-1]\n",
        "# One-hot encode the labels\n",
        "one_hot_labels = tf.keras.utils.to_categorical(labels, num_classes=total_words)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cECbqT-blMk-"
      },
      "source": [
        "### Train a (Better) Text Generation Model\n",
        "\n",
        "With more data, we'll cut off after 100 epochs to avoid keeping you here all day. You'll also want to change your runtime type to GPU if you haven't already (you'll need to re-run the above cells if you change runtimes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7nHOp6uWlP_P"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras.models import Sequential\n",
        "from tensorflow.keras.layers import Embedding, LSTM, Dense, Bidirectional\n",
        "\n",
        "model = Sequential()\n",
        "model.add(Embedding(total_words, 64, input_length=max_sequence_len-1))\n",
        "model.add(Bidirectional(LSTM(20)))\n",
        "model.add(Dense(total_words, activation='softmax'))\n",
        "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
        "history = model.fit(input_sequences, one_hot_labels, epochs=100, verbose=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MgvIz20nlQcq"
      },
      "source": [
        "### View the Training Graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rOqmmarvlSLh"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "def plot_graphs(history, string):\n",
        "  plt.plot(history.history[string])\n",
        "  plt.xlabel(\"Epochs\")\n",
        "  plt.ylabel(string)\n",
        "  plt.show()\n",
        "\n",
        "plot_graphs(history, 'accuracy')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ISLZZGlQlSxh"
      },
      "source": [
        "### Generate better lyrics!\n",
        "\n",
        "This time around, we should be able to get a more interesting output with less repetition."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P96oVMk3lU7y"
      },
      "outputs": [],
      "source": [
        "seed_text = \"im feeling chills\"\n",
        "next_words = 100\n",
        "  \n",
        "for _ in range(next_words):\n",
        "\ttoken_list = tokenizer.texts_to_sequences([seed_text])[0]\n",
        "\ttoken_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')\n",
        "\tpredicted = np.argmax(model.predict(token_list), axis=-1)\n",
        "\toutput_word = \"\"\n",
        "\tfor word, index in tokenizer.word_index.items():\n",
        "\t\tif index == predicted:\n",
        "\t\t\toutput_word = word\n",
        "\t\t\tbreak\n",
        "\tseed_text += \" \" + output_word\n",
        "print(seed_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "upgJKV8_oRU9"
      },
      "source": [
        "### Varying the Possible Outputs\n",
        "\n",
        "In running the above, you may notice that the same seed text will generate similar outputs. This is because the code is currently always choosing the top predicted class as the next word. What if you wanted more variance in the output? \n",
        "\n",
        "Switching from `model.predict_classes` to `model.predict_proba` will get us all of the class probabilities. We can combine this with `np.random.choice` to select a given predicted output based on a probability, thereby giving a bit more randomness to our outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lZe9gaJeoGVP"
      },
      "outputs": [],
      "source": [
        "# Test the method with just the first word after the seed text\n",
        "seed_text = \"im feeling chills\"\n",
        "next_words = 100\n",
        "  \n",
        "token_list = tokenizer.texts_to_sequences([seed_text])[0]\n",
        "token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')\n",
        "predicted_probs = model.predict(token_list)[0]\n",
        "predicted = np.random.choice([x for x in range(len(predicted_probs))], \n",
        "                             p=predicted_probs)\n",
        "# Running this cell multiple times should get you some variance in output\n",
        "print(predicted)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ee7WKgRGrJy1"
      },
      "outputs": [],
      "source": [
        "# Use this process for the full output generation\n",
        "seed_text = \"im feeling chills\"\n",
        "next_words = 100\n",
        "  \n",
        "for _ in range(next_words):\n",
        "  token_list = tokenizer.texts_to_sequences([seed_text])[0]\n",
        "  token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')\n",
        "  predicted_probs = model.predict(token_list)[0]\n",
        "  predicted = np.random.choice([x for x in range(len(predicted_probs))],\n",
        "                               p=predicted_probs)\n",
        "  output_word = \"\"\n",
        "  for word, index in tokenizer.word_index.items():\n",
        "    if index == predicted:\n",
        "      output_word = word\n",
        "      break\n",
        "  seed_text += \" \" + output_word\n",
        "print(seed_text)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "l10c04_nlp_optimizing_the_text_generation_model.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
